package com.peng.leetcode.unionFind;

import java.util.ArrayList;
import java.util.List;

/**
 * HitBricks
 * <p>
 * 803. 打砖块
 * <p>
 * 有一个 m x n 的二元网格，其中 1 表示砖块，0 表示空白。砖块 稳定（不会掉落）的前提是：
 * <p>
 * 一块砖直接连接到网格的顶部，或者
 * 至少有一块相邻（4 个方向之一）砖块 稳定 不会掉落时
 * 给你一个数组 hits ，这是需要依次消除砖块的位置。每当消除 hits[i] = (rowi, coli) 位置上的砖块时，对应位置的砖块（若存在）会消失，然后其他的砖块可能因为这一消除操作而掉落。一旦砖块掉落，它会立即从网格中消失（即，它不会落在其他稳定的砖块上）。
 * <p>
 * 返回一个数组 result ，其中 result[i] 表示第 i 次消除操作对应掉落的砖块数目。
 * <p>
 * 注意，消除可能指向是没有砖块的空白位置，如果发生这种情况，则没有砖块掉落。
 * <p>
 * 链接：https://leetcode-cn.com/problems/bricks-falling-when-hit
 *
 * @author: lupeng6
 * @create: 2021/1/19 11:10
 */
public class HitBricks {

    public static void main(String[] args) {
//        int[][] grid = {{1, 0, 0, 0}, {1, 1, 1, 0}};
        int[][] grid = {{1, 0, 1}, {0, 0, 1}};
        int[][] hits = {{1, 0}, {0, 0}};
        int[] ints = new HitBricks().hitBricks(grid, hits);
        for (int anInt : ints) {
            System.out.println(anInt);
        }
    }

    // 前后左右
    private static int[][] operation = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}};

    private int[][] grid;

    private int rows;

    private int cols;

    public int[] hitBricks(int[][] grid, int[][] hits) {
        this.rows = grid.length;
        this.cols = grid[0].length;
        this.grid = grid;

        // copy 原数据
        int[][] copy = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                copy[i][j] = grid[i][j];
            }
        }

        // 将砖块打掉
        for (int[] hit : hits) {
            int x = hit[0];
            int y = hit[1];
            copy[x][y] = 0;
        }

        // size 网格的大小, 也表示虚拟的屋顶
        int size = rows * cols;
        UnionFind unionFind = new UnionFind(size + 1);

        // 建图
        for (int i = 0; i < cols; i++) {
            // 当顶部 为 砖块时输入并查集
            int x = copy[0][i];
            if (x == 1) {
                unionFind.union(i, size);
            }
        }

        // 其余的网格, 如果是砖块则向上 向左看下是否是砖块, 如果是的话输入并查集
        for (int i = 1; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                int xy = copy[i][j];
                if (xy != 1) {
                    continue;
                }
                int currIndex = getIndex(i, j);
                // 是砖块
                // 向上看
                int top = copy[i - 1][j];
                if (top == 1) {
                    int topIndex = getIndex(i - 1, j);
                    unionFind.union(topIndex, currIndex);
                }
                // 向左看
                if (j > 0) {
                    int left = copy[i][j - 1];
                    if (left == 1) {
                        int leftIndex = getIndex(i, j - 1);
                        unionFind.union(leftIndex, currIndex);
                    }
                }
            }
        }

        List<Integer> res = new ArrayList<>();

        // 按照 hits 逆序, 在 copy 中补回砖块, 将每一次补回的砖块能连接到屋顶的砖块的增量记录到 res 中
        for (int i = hits.length - 1; i >= 0; i--) {
            int[] hit = hits[i];
            int x = hit[0];
            int y = hit[1];
            if (grid[x][y] == 1) {
                res.add(0);
                continue;
            }
            // 连接到屋顶的砖块数量
            int origin = unionFind.getSize(size);

            // 如果是第一行, 输入并查集
            if (x == 0) {
                unionFind.union(y, size);
            }

            // 在4个方向上查找合并
            for (int[] operate : operation) {
                int newX = x + operate[0];
                int newY = y + operate[1];
                boolean area = isArea(newX, newY);
                if (!area) {
                    continue;
                }
                int newXy = copy[newX][newY];
                if (newXy == 1) {
                    int currIndex = getIndex(x, y);
                    int newIndex = getIndex(newX, newY);
                    unionFind.union(currIndex, newIndex);
                }
            }

            int currSize = unionFind.getSize(size);
            int addSize = Math.max(0, currSize - origin - 1);
            res.add(addSize);

            // 补上砖块
            copy[x][y] = 1;
        }

        int index = 0;
        int[] result = new int[res.size()];
        for (int i = res.size() - 1; i >= 0; i--) {
            result[index] = res.get(i);
            index++;
        }

        return result;
    }


    static class UnionFind {
        // 当前节点的父节点
        private int[] parent;
        // 以当前节点为根节点的子树节点数量
        private int[] size;

        public UnionFind(int n) {
            this.parent = new int[n];
            this.size = new int[n];
            for (int i = 0; i < n; i++) {
                parent[i] = i;
                size[i] = 1;
            }
        }

        public int find(int x) {
            if (x != parent[x]) {
                parent[x] = find(parent[x]);
                x = parent[x];
            }
            return x;
        }

        public void union(int x, int y) {
            int xRoot = find(x);
            int yRoot = find(y);
            if (xRoot != yRoot) {
                parent[xRoot] = yRoot;

                // 将以 x 为根节点的数量 更新
                int yRootSize = size[yRoot];
                int xRootSize = size[xRoot];
                size[yRoot] = yRootSize + xRootSize;
            }
        }

        // 获取 节点 x 的root节点, 挂载的节点数量
        public int getSize(int x) {
            int root = find(x);
            return size[root];
        }
    }

    private int getIndex(int x, int y) {
        return x * cols + y;
    }


    private boolean isArea(int x, int y) {
        if (x >= 0 && x < this.grid.length) {
            if (y >= 0 && y < this.grid[0].length) {
                return true;
            }
        }
        return false;
    }


    public static class Solution {

        private int rows;
        private int cols;

        public static final int[][] DIRECTIONS = {{0, 1}, {1, 0}, {-1, 0}, {0, -1}};

        public int[] hitBricks(int[][] grid, int[][] hits) {
            this.rows = grid.length;
            this.cols = grid[0].length;

            // 第 1 步：把 grid 中的砖头全部击碎，通常算法问题不能修改输入数据，这一步非必需，可以认为是一种答题规范
            int[][] copy = new int[rows][cols];
            for (int i = 0; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    copy[i][j] = grid[i][j];
                }
            }

            // 把 copy 中的砖头全部击碎
            for (int[] hit : hits) {
                copy[hit[0]][hit[1]] = 0;
            }

            // 第 2 步：建图，把砖块和砖块的连接关系输入并查集，size 表示二维网格的大小，也表示虚拟的「屋顶」在并查集中的编号
            int size = rows * cols;
            UnionFind unionFind = new UnionFind(size + 1);

            // 将下标为 0 的这一行的砖块与「屋顶」相连
            for (int j = 0; j < cols; j++) {
                if (copy[0][j] == 1) {
                    unionFind.union(j, size);
                }
            }

            // 其余网格，如果是砖块向上、向左看一下，如果也是砖块，在并查集中进行合并
            for (int i = 1; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    if (copy[i][j] == 1) {
                        // 如果上方也是砖块
                        if (copy[i - 1][j] == 1) {
                            unionFind.union(getIndex(i - 1, j), getIndex(i, j));
                        }
                        // 如果左边也是砖块
                        if (j > 0 && copy[i][j - 1] == 1) {
                            unionFind.union(getIndex(i, j - 1), getIndex(i, j));
                        }
                    }
                }
            }

            // 第 3 步：按照 hits 的逆序，在 copy 中补回砖块，把每一次因为补回砖块而与屋顶相连的砖块的增量记录到 res 数组中
            int hitsLen = hits.length;
            int[] res = new int[hitsLen];
            for (int i = hitsLen - 1; i >= 0; i--) {
                int x = hits[i][0];
                int y = hits[i][1];

                // 注意：这里不能用 copy，语义上表示，如果原来在 grid 中，这一块是空白，这一步不会产生任何砖块掉落
                // 逆向补回的时候，与屋顶相连的砖块数量也肯定不会增加
                if (grid[x][y] == 0) {
                    continue;
                }

                // 补回之前与屋顶相连的砖块数
                int origin = unionFind.getSize(size);

                // 注意：如果补回的这个结点在第 1 行，要告诉并查集它与屋顶相连（逻辑同第 2 步）
                if (x == 0) {
                    unionFind.union(y, size);
                }

                // 在 4 个方向上看一下，如果相邻的 4 个方向有砖块，合并它们
                for (int[] direction : DIRECTIONS) {
                    int newX = x + direction[0];
                    int newY = y + direction[1];

                    if (inArea(newX, newY) && copy[newX][newY] == 1) {
                        unionFind.union(getIndex(x, y), getIndex(newX, newY));
                    }
                }

                // 补回之后与屋顶相连的砖块数
                int current = unionFind.getSize(size);
                // 减去的 1 是逆向补回的砖块（正向移除的砖块），与 0 比较大小，是因为存在一种情况，添加当前砖块，不会使得与屋顶连接的砖块数更多
                res[i] = Math.max(0, current - origin - 1);

                // 真正补上这个砖块
                copy[x][y] = 1;
            }
            return res;
        }

        /**
         * 输入坐标在二维网格中是否越界
         *
         * @param x
         * @param y
         * @return
         */
        private boolean inArea(int x, int y) {
            return x >= 0 && x < rows && y >= 0 && y < cols;
        }

        /**
         * 二维坐标转换为一维坐标
         *
         * @param x
         * @param y
         * @return
         */
        private int getIndex(int x, int y) {
            return x * cols + y;
        }

        private class UnionFind {

            /**
             * 当前结点的父亲结点
             */
            private int[] parent;
            /**
             * 以当前结点为根结点的子树的结点总数
             */
            private int[] size;

            public UnionFind(int n) {
                parent = new int[n];
                size = new int[n];
                for (int i = 0; i < n; i++) {
                    parent[i] = i;
                    size[i] = 1;
                }
            }

            /**
             * 路径压缩，只要求每个不相交集合的「根结点」的子树包含的结点总数数值正确即可，因此在路径压缩的过程中不用维护数组 size
             *
             * @param x
             * @return
             */
            public int find(int x) {
                if (x != parent[x]) {
                    parent[x] = find(parent[x]);
                }
                return parent[x];
            }

            public void union(int x, int y) {
                int rootX = find(x);
                int rootY = find(y);

                if (rootX == rootY) {
                    return;
                }
                parent[rootX] = rootY;
                // 在合并的时候维护数组 size
                size[rootY] += size[rootX];
            }

            /**
             * @param x
             * @return x 在并查集的根结点的子树包含的结点总数
             */
            public int getSize(int x) {
                int root = find(x);
                return size[root];
            }
        }
    }


}
